library(cregg)
library(dplyr)
library(ggplot2)
library(stringr)
library(formula.tools)
library(stringr)
library(patchwork)
library(tikzDevice)


fix_duplicate_level_names <- function(cregg_function_input, formula_to_be_used) {
  # cregg::cj() rejects data in which two attributes share a level name.
  # This function makes level names unique across the right-hand-side
  # variables of `formula_to_be_used` by appending "@" markers: the first
  # attribute (in formula order) that uses a name keeps it, the second gets
  # "name@", the third "name@@", and so on, so no two attributes collide.
  # fix_level_name_in_cregg_output() strips every "@" again afterwards.
  #
  # Args:
  #   cregg_function_input: data frame containing the conjoint attributes.
  #   formula_to_be_used: formula whose right-hand-side variables are the
  #     attributes to de-duplicate; the response (if any) is ignored.
  # Returns:
  #   `cregg_function_input` with repeated factor level names renamed. Level
  #   order is preserved; columns that are not factors are left unchanged.
  # Errors:
  #   If a right-hand-side variable is not a column of the data.

  # The last element of a formula is its right-hand side (for `y ~ x` and `~ x`).
  rhs <- formula_to_be_used[[length(formula_to_be_used)]]
  # all.vars() yields bare names, so backticked names need no extra cleanup.
  variables_to_be_used <- all.vars(rhs)

  missing_variables <- setdiff(variables_to_be_used, names(cregg_function_input))
  if (length(missing_variables) > 0) {
    stop(
      "Variables not found in the data: ",
      paste(missing_variables, collapse = ", "),
      call. = FALSE
    )
  }

  # For each original level name: how many attributes processed so far used it.
  times_seen <- integer(0)

  for (variable in variables_to_be_used) {
    # [[ ]] returns the column itself for both data.frames and tibbles.
    attribute <- cregg_function_input[[variable]]
    level_names <- levels(attribute)
    if (is.null(level_names)) {
      next
    }

    n_previous <- unname(times_seen[level_names])
    n_previous[is.na(n_previous)] <- 0L

    # Renaming through levels<- keeps the original level order, whereas
    # recode_factor() orders levels by its replacement list.
    levels(attribute) <- paste0(level_names, strrep("@", n_previous))
    cregg_function_input[[variable]] <- attribute

    times_seen[level_names] <- n_previous + 1L
  }

  cregg_function_input
}

fix_level_name_in_cregg_output <- function(cregg_output) {
  # Undoes fix_duplicate_level_names(): removes every "@" marker from the
  # `level` column of a cregg result so the original level names are shown.
  #
  # The column is returned as a factor. If it already is one, its level order
  # is kept (as.factor() on the cleaned strings would sort levels
  # alphabetically and reorder the plots' y axes); levels that become
  # identical once the markers are removed are merged into a single level.
  # A non-factor column is converted with levels in order of first appearance.
  #
  # Args:
  #   cregg_output: data frame with a `level` column.
  # Returns:
  #   `cregg_output` with the cleaned `level` factor.

  level <- cregg_output[["level"]]

  if (is.factor(level)) {
    # levels<- merges duplicated new names into one level.
    levels(level) <- gsub("@", "", levels(level), fixed = TRUE)
  } else {
    cleaned <- gsub("@", "", as.character(level), fixed = TRUE)
    level <- factor(cleaned, levels = unique(cleaned[!is.na(cleaned)]))
  }

  cregg_output[["level"]] <- level
  cregg_output
}


make_vote_choice_plot <- function(cregg_output,
                                  show_only_relevant_factors = TRUE,
                                  relevant_factors = c("Father's Job", "University",
                                                       "Occupation", "High School")) {
  # Builds the paper's vote-choice marginal-means plot: one facet per
  # conjoint feature, a point and interval per level, and a dashed reference
  # line at 0.5.
  #
  # Args:
  #   cregg_output: marginal-means data frame with columns feature, level,
  #     estimate, lower and upper.
  #   show_only_relevant_factors: if TRUE, keep only `relevant_factors`, zoom
  #     the x axis to [0.4, 0.6] and use size-8 y labels; otherwise show every
  #     feature on [0.3, 0.7] with size-6 labels.
  #   relevant_factors: feature names kept when show_only_relevant_factors
  #     is TRUE.
  # Returns:
  #   A ggplot object.

  if (show_only_relevant_factors) {
    cregg_output <- cregg_output %>%
      filter(feature %in% relevant_factors)
    plot_limits <- c(0.4, 0.6)
    y_label_size <- element_text(size = 8)
  } else {
    # wider range so that the policy items fit
    plot_limits <- c(0.3, 0.7)
    # smaller text to avoid overlaps between policy item labels
    y_label_size <- element_text(size = 6)
  }

  cregg_output %>%
    ggplot(aes(x = estimate, y = level)) +
    geom_point() +
    geom_errorbar(aes(xmin = lower, xmax = upper), width = 0) +
    facet_wrap(~feature, dir = "v", strip.position = "top",
               scales = "free_y", ncol = 1) +
    scale_x_continuous(name = "Vote choice Marginal Means",
                       breaks = seq(0, 1, by = 0.05)) +
    # Zoom through the coordinate system rather than scale limits: scale
    # limits remove any point or interval reaching past them, whereas
    # coord_cartesian() keeps the data and only clips what is drawn.
    coord_cartesian(xlim = plot_limits) +
    geom_vline(xintercept = 0.5, linetype = "dashed") +
    theme_bw() +
    theme(axis.text.y = y_label_size,
          legend.position = "none")
}

fix_outcome_var_names_for_plot <- function(region_output = good_region_output,
                                           country_output = good_country_output) {
  # Adds a human-readable `outcome_var` label column to the regional and
  # national cregg outputs.
  #
  # R passes data frames by value, so the caller's objects are NOT modified;
  # use the returned list, e.g.
  #   labelled <- fix_outcome_var_names_for_plot(good_region_output,
  #                                              good_country_output)
  #
  # Args:
  #   region_output, country_output: marginal-means data frames. They default
  #     to the global good_region_output / good_country_output so the
  #     zero-argument call keeps working.
  # Returns:
  #   A list with elements `region` and `country` holding the labelled copies.

  # rep(..., nrow()) also works for zero-row inputs.
  region_output[["outcome_var"]] <-
    rep("Good for the Region", nrow(region_output))
  country_output[["outcome_var"]] <-
    rep("Good for the Country", nrow(country_output))

  list(region = region_output, country = country_output)
}

make_local_national_comparison_plot <- function(cregg_output_region,
                                                cregg_output_national,
                                                show_only_relevant_factors = TRUE,
                                                relevant_factors = c("Father's Job", "University",
                                                                     "Occupation", "High School")) {
  # Builds a two-panel patchwork figure comparing "good for the region" and
  # "good for the country" marginal means side by side. Each panel has one
  # facet per conjoint feature and a dashed reference line at 0.5.
  #
  # Args:
  #   cregg_output_region, cregg_output_national: marginal-means data frames
  #     with columns feature, level, estimate, lower and upper.
  #   show_only_relevant_factors: if TRUE, keep only `relevant_factors`, zoom
  #     the x axis to [0.4, 0.6] and use size-8 y labels; otherwise show every
  #     feature on [0.3, 0.7] with size-6 labels.
  #   relevant_factors: feature names kept when show_only_relevant_factors
  #     is TRUE.
  # Returns:
  #   A patchwork object (regional panel left, national panel right).

  if (show_only_relevant_factors) {
    cregg_output_region <- cregg_output_region %>%
      filter(feature %in% relevant_factors)
    cregg_output_national <- cregg_output_national %>%
      filter(feature %in% relevant_factors)
    plot_limits <- c(0.4, 0.6)
    y_label_size <- element_text(size = 8)
  } else {
    # wider range so that the policy items fit
    plot_limits <- c(0.3, 0.7)
    # smaller text to avoid overlaps between policy item labels
    y_label_size <- element_text(size = 6)
  }

  # Shared layout of one marginal-means panel.
  build_panel <- function(panel_data, panel_title) {
    panel_data %>%
      ggplot(aes(x = estimate, y = level)) +
      geom_point() +
      geom_errorbar(aes(xmin = lower, xmax = upper), width = 0) +
      facet_wrap(~feature, dir = "v", strip.position = "top",
                 scales = "free_y", ncol = 1) +
      scale_x_continuous(name = "Marginal Means",
                         breaks = seq(0, 1, by = 0.05)) +
      # Zoom through the coordinate system rather than scale limits so that
      # estimates or intervals reaching past the range are clipped, not dropped.
      coord_cartesian(xlim = plot_limits) +
      geom_vline(xintercept = 0.5, linetype = "dashed") +
      ggtitle(panel_title) +
      theme_bw() +
      theme(legend.position = "none")
  }

  region_panel <- build_panel(cregg_output_region, "Good for Region Marginal Means") +
    theme(axis.text.y = y_label_size)

  # The right-hand panel relies on the left panel's level labels, so its
  # y axis is blanked. This assumes both outputs list the same features and
  # levels in the same order — verify if the inputs can differ.
  national_panel <- build_panel(cregg_output_national, "Good for Country Marginal Means") +
    theme(axis.title.y = element_blank(),
          axis.text.y = element_blank(),
          axis.ticks.y = element_blank())

  region_panel + national_panel
}




# ---- Paths ------------------------------------------------------------------

path_to_data <- "data/first_survey_formatted.rds"

path_to_figures <- "figures"
paper_figures_folder <- "paper_plots"
appendix_figures_folder <- "appendix_plots"

vote_choice_outcome_plot <- "vote_choice_mm.eps"
comparison_plot <- "region_nation_comparison_mm.eps"

vote_choice_outcome_plot_app <- "vote_choice_mm_appendix.eps"
comparison_plot_app <- "region_nation_comparison_mm_appendix.eps"

# ---- Data and model formulas --------------------------------------------------

survey_data <- readRDS(path_to_data)

# The three outcomes are modelled on the same conjoint attributes.
good_region_formula <-
  good_region ~ `Father's Job` + `Incumbent Status` + Gender + `Years in Employment` + University + Occupation + `High School` + `Policy Position`
good_country_formula <-
  good_country ~ `Father's Job` + `Incumbent Status` + Gender + `Years in Employment` + University + Occupation + `High School` + `Policy Position`
vote_choice_formula <-
  vote_choice ~ `Father's Job` + `Incumbent Status` + Gender + `Years in Employment` + University + Occupation + `High School` + `Policy Position`

# The formulas share their right-hand side, so de-duplicating level names
# for one of them covers the attributes of all three models.
survey_data <- fix_duplicate_level_names(survey_data, good_region_formula)

# ---- Marginal means -----------------------------------------------------------

# cregg::cj() with estimate = "mm" computes marginal means; `id` names the
# respondent identifier column.
estimate_marginal_means <- function(formula, data) {
  cregg::cj(
    formula = formula,
    data = data,
    id = ~ResponseId,
    estimate = "mm"
  )
}

vote_choice_output <- estimate_marginal_means(vote_choice_formula, survey_data)
good_region_output <- estimate_marginal_means(good_region_formula, survey_data)
good_country_output <- estimate_marginal_means(good_country_formula, survey_data)

vote_choice_output <- fix_level_name_in_cregg_output(vote_choice_output)
good_region_output <- fix_level_name_in_cregg_output(good_region_output)
good_country_output <- fix_level_name_in_cregg_output(good_country_output)

# Tag each output with a readable outcome label. This is assigned here at top
# level so the outputs used below actually carry the column.
good_region_output <- good_region_output %>%
  mutate(outcome_var = "Good for the Region")
good_country_output <- good_country_output %>%
  mutate(outcome_var = "Good for the Country")

# ---- Plots ------------------------------------------------------------------

vote_choice_paper_plot <- make_vote_choice_plot(vote_choice_output)
region_national_comparison_paper_plot <- make_local_national_comparison_plot(
  good_region_output, good_country_output)

vote_choice_appendix_plot <- make_vote_choice_plot(
  vote_choice_output, show_only_relevant_factors = FALSE)
region_national_comparison_appendix_plot <- make_local_national_comparison_plot(
  good_region_output, good_country_output, show_only_relevant_factors = FALSE)

# ---- Save -------------------------------------------------------------------

# Writes `plot` as an EPS file (height/width in inches) to
# path_to_figures/folder/file_name. The folder is created first because
# ggsave() does not create missing output directories by default.
save_eps_plot <- function(plot, folder, file_name, height, width) {
  output_path <- file.path(path_to_figures, folder, file_name)
  dir.create(dirname(output_path), recursive = TRUE, showWarnings = FALSE)
  ggsave(
    filename = output_path,
    plot = plot,
    height = height,
    width = width,
    units = "in",
    device = "eps"
  )
  invisible(output_path)
}

save_eps_plot(vote_choice_paper_plot, paper_figures_folder,
              vote_choice_outcome_plot, height = 8, width = 6)
save_eps_plot(region_national_comparison_paper_plot, paper_figures_folder,
              comparison_plot, height = 9.1, width = 10)
save_eps_plot(vote_choice_appendix_plot, appendix_figures_folder,
              vote_choice_outcome_plot_app, height = 10, width = 8)
save_eps_plot(region_national_comparison_appendix_plot, appendix_figures_folder,
              comparison_plot_app, height = 10, width = 8)




